"""
Author: Yanrui Hu
Date: 2023/3/18
Description: 有若干节点，每个节点上有能量塔，所有节点构成一棵树...探究每个节点上的能量值
Keyword: 能量塔, 能量场, 能量扩散, Energy-Tower, Energy-Field, Energy-Radiation
Version: 1.0.0
Reference: CZY聊天记录

题目描述:
现在有若干节点，每个节点上有能量塔，所有节点构成一棵树。某个节点u可以为所有和u距离不超过给定值的节点各提供一点能力。此处距离的定义为两个节点之间经过的边的数量。特别的，节点u到本身的距离为0.现在给出每个节点上能量塔可以为多远的距离内的点提供能量。小美想要探究每个节点上的能量值具体是多少。你的任务是帮助小美计算得到，并以此输出。
输入描述：第一行一个整数N，表示节点的数量，接下来一行N个以空格分开的整数，以此表示节点1，节点2，...，节点N的能量塔所能提供的最远距离。接下来N-1行，每行两个整数，表示两个点之间有一条边。
输出：一行N个整数，以此表示每个节点上的能量值
"""

from typing import List, Tuple

"""
二叉树的最近公共祖先解法
class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        if root is None or root is p or root is q:
            return root

        left = self.lowestCommonAncestor(root.left, p, q)
        right = self.lowestCommonAncestor(root.right, p, q)

        if left and right:
            return root
        if left:
            return left
        return right """


def build_tree(n, edges):
    """
    通过 edges 构造树形结构, n 叉树
    Params:
     n: int, 表示节点（能量塔）的个数
     edges: List[Tuple[int, int]], 边的集合，表示两个点之间有一条边
    Return:
     包含每个节点的父亲的列表
    """
    parents = [0] * (n + 1)  # 节点的编号从1开始
    for i, j in edges:
        parents[j] = i
    return parents


def lca(parents, p, q) -> Tuple[int, int, int]:
    """
    寻找到 p, q 的最近公共祖先
    Params:
     parents: List[int]
     p, q: int, 节点编号
    Return:
     最近公共祖先的节点编号, 到p的距离, 到q的距离
    """
    p_ancestor = set()
    pp = p
    while pp:
        p_ancestor.add(pp)
        pp = parents[pp]

    ance = 0
    pq = q
    q_to_ance = 0  # q到ance的距离

    while pq:
        if pq in p_ancestor:
            ance = pq
            break
        else:
            pq = parents[pq]
            q_to_ance += 1

    pp = p
    p_to_ance = 0  # p到ance的距离
    while pp:
        if pp == ance:
            break
        else:
            pp = parents[pp]
            p_to_ance += 1

    return (ance, p_to_ance, q_to_ance)


def calc_energy(n, max_distance, edges) -> List[int]:
    '''
    计算每个节点的能量值
    Params:
     n: int, 表示节点（能量塔）的个数
     max_distance: List[int], 能量塔所能提供的最远距离
     edges: List[Tuple[int, int]], 边的集合，表示两个点之间有一条边
    Returns:
     List[int] 一行N个整数, 以此表示每个节点上的能量值
    '''
    energies = [0] * (n + 1)  # 节点的编号从1开始

    parents = build_tree(n, edges)
    # CASE 1
    # p, q = 3, 8
    dis_dic = {}
    for p in range(1, n+1):
        for q in range(p+1, n+1):
            ance, p_to_ance, q_to_ance = lca(parents, p, q)
            dis_dic[(p, q)] = p_to_ance + q_to_ance

    for i in range(1, n+1):
        for j in range(1, n+1):
            if i == j:
                energies[i] += 1
                continue

            dis_i_j = dis_dic[(i, j)] if i < j else dis_dic[(j, i)]
            if dis_i_j <= max_distance[j-1]:  # 如果第j个能量塔可以覆盖到 i所在的区域
                energies[i] += 1
    return energies


if __name__ == '__main__':
    # Handle input
    # n = int(input())

    # str_lis = input().split()
    # max_distance = list(map(int, str_lis))

    # edges = []
    # for i in range(n-1):
    #     str_tuple = tuple(input().split())
    #     edges.append(tuple(int(x) for x in str_tuple))

    #     print("edges: ", edges)

    # CASE 1
    print("*" * 8, "USE CASE 1", "*"*8)
    n = 3
    max_distance = [1, 1, 1]
    edges = [
        (1, 2),
        (2, 3)
    ]
    ANS = [2, 3, 2]

    # CASE 2
    print("*" * 8, "USE CASE 2", "*"*8)
    n = 10
    max_distance = list(map(int, "1 2 1 1 1 2 3 1 1 1".split()))
    edges = [
        (1, 2),
        (2, 3),
        (1, 4),
        (2, 5),
        (4, 6),
        (3, 7),
        (5, 8),
        (1, 9),
        (2, 10)]
    ANS = [6, 6, 3, 4, 4, 2, 3, 3, 3, 3]

    energies = calc_energy(n, max_distance, edges)
    print(energies[1:])
